//
// Created by llp on 2020/12/14.
//
#include "pmemrep.h"
#include "pure_mem/rangearena/range_arena_rebuild.h"
#include "pure_mem/rangearena/range_arena.h"
#include "util/random.h"
#include "util/testharness.h"

namespace rocksdb {
    // Numeric key type used throughout these tests; keys are encoded as their
    // decimal string representation before being stored in the rep.
    using Key = uint64_t;

    // Builds a length-prefixed test entry for `*key`.
    //
    // Layout: varint32(len) | decimal digits of *key | 0x00 (ts_size byte) |
    // 8 zero bytes that Decode() strips, where len = digits + 1 + 8.
    // The whole buffer is zero-filled, so the byte right after the digits is
    // NUL.
    //
    // Ownership: the buffer is allocated with new[] and belongs to the caller
    // (release with delete[]).
    static const char* Encode(const Key* key) {
      const std::string digits = std::to_string(*key);
      // Explicit widths: avoid the implicit size_t -> int narrowing.
      const uint32_t len = static_cast<uint32_t>(digits.size() + 8 + 1);
      const size_t key_size = static_cast<size_t>(VarintLength(len)) + len;
      char* ret = new char[key_size]();  // value-initialized: every byte is 0
      char* p = EncodeVarint32(ret, len);
      memcpy(p, digits.data(), digits.size());
      return ret;
    }

    // Extracts the user-key portion of an entry produced by Encode().
    //
    // The length-prefixed payload is split as:
    //   user key | ts bytes | ts_size byte | 8 trailing bytes
    // Returns an empty Slice when the payload is too short to contain the
    // ts_size byte plus the 8 trailing bytes, or when ts_size does not fit in
    // the payload. The returned Slice points into `key`; it owns no memory.
    static Slice Decode(const char* key) {
      const Slice buf = GetLengthPrefixedSlice(key);
      // Need the 8 trailing bytes plus the ts_size byte; otherwise the
      // subtraction below would wrap and the index would read out of bounds.
      if (buf.size() < 8 + 1) {
        return Slice();
      }
      const size_t mvcc_size = buf.size() - 8;
      // Read the size byte as unsigned so values >= 0x80 are not negative.
      const size_t ts_size = static_cast<unsigned char>(buf[mvcc_size - 1]);
      if (ts_size >= mvcc_size) {
        // Slice() rather than a null pointer: callers read .data() as a string.
        return Slice();
      }
      return Slice(buf.data(), mvcc_size - 1 - ts_size);
    }

    // Comparator handed to PureMemRep in these tests: each encoded entry is
    // reduced to its user key with Decode(), and keys are ordered with
    // Slice::compare.
    class TestKeyComparator : public MemTableRep::KeyComparator {
    public:
        // Maps an encoded (length-prefixed) entry to its user key.
        DecodedType decode_key(const char* key) const override {
          const DecodedType user_key = Decode(key);
          return user_key;
        }

        // Orders two encoded entries by their decoded user keys.
        int operator()(const char* lhs, const char* rhs) const override {
          const Slice left = Decode(lhs);
          const Slice right = Decode(rhs);
          return left.compare(right);
        }

        // Orders an encoded entry against an already-decoded user key.
        int operator()(const char* encoded,
                       const Slice& user_key) const override {
          const Slice decoded = Decode(encoded);
          return decoded.compare(user_key);
        }

        ~TestKeyComparator() override {}
    };

    // Alias used by the fixture helpers below.
    using TestPureMemRep = PureMemRep;
    // Offset added to every randomly generated test key. With R = 5000 (as in
    // the tests) all keys have seven decimal digits, so their numeric order
    // matches the byte order of their encoded strings.
    constexpr size_t standard_key = 1000000;

    class PureMemRepTest : public testing::Test {
    public:
        // Inserts `key` into `list` using the same sequence as the tests below
        // (AllocatePure -> memcpy -> Insert -> AllocateOK) and records it in
        // keys_ so Validate() can check it. A key already recorded by this
        // fixture is not inserted a second time.
        void Insert(TestPureMemRep* list, Key key) {
          if (!keys_.insert(key).second) {
            return;
          }
          // Size of the buffer Encode() produces: varint32 prefix + payload,
          // where payload = decimal digits + 1 ts_size byte + 8 trailing bytes.
          const uint32_t payload_len =
              static_cast<uint32_t>(std::to_string(key).length() + 8 + 1);
          const uint32_t key_size =
              static_cast<uint32_t>(VarintLength(payload_len)) + payload_len;
          // Not freed: `user_key` points into this buffer and is handed to
          // the rep, which may keep referring to it.
          const char* encoded = Encode(&key);
          Slice user_key = Decode(encoded);
          char* buf = nullptr;
          void* range_arena = nullptr;
          void* handle =
              list->AllocatePure(key_size, &buf, user_key, &range_arena);
          memcpy(buf, encoded, key_size);
          list->Insert(handle);
          list->AllocateOK(user_key, key_size, handle, range_arena);
        }

        // Checks that every key recorded by Insert() is found by Contains(),
        // and that a full forward scan visits exactly those keys, in the order
        // the comparator imposes on their decimal strings.
        void Validate(TestPureMemRep* list) {
          for (Key key : keys_) {
            const char* probe = Encode(&key);
            const bool found = list->Contains(probe);
            delete[] probe;
            ASSERT_TRUE(found) << "missing key " << key;
          }
          // The rep orders entries by their decimal strings. That order differs
          // from numeric order when keys have different digit counts, so build
          // the expected sequence as strings.
          std::set<std::string> expected;
          for (Key key : keys_) {
            expected.insert(std::to_string(key));
          }
          TestPureMemRep::Iterator iter(list->GetARTList());
          ASSERT_FALSE(iter.Valid());
          iter.SeekToFirst();
          for (const std::string& want : expected) {
            ASSERT_TRUE(iter.Valid());
            ASSERT_EQ(want, Decode(iter.key()).ToString());
            iter.Next();
          }
          ASSERT_FALSE(iter.Valid());
        }

    private:
        // Keys inserted through Insert().
        std::set<Key> keys_;
    };

    TEST_F(PureMemRepTest, Empty) {
      // A freshly built rep holds nothing: Contains() must fail and every
      // positioning call on a new iterator must leave it invalid.
      const TestKeyComparator cmp;
      size_t lookahead = 0;
      Allocator* no_allocator = nullptr;
      const SliceTransform* no_transform = nullptr;
      PureMemFactory factory;
      // NOTE(review): `rep` and the Encode() buffers below are never freed.
      PureMemRep* rep =
          new PureMemRep(cmp, no_allocator, no_transform, lookahead);

      Key absent = 10;
      ASSERT_FALSE(rep->Contains(Encode(&absent)));

      PureMemRep::Iterator it(rep->GetARTList());
      ASSERT_FALSE(it.Valid());
      it.SeekToFirst();
      ASSERT_FALSE(it.Valid());

      Key target = 100;
      const char* encoded_target = Encode(&target);
      it.Seek(nullptr, encoded_target);
      ASSERT_FALSE(it.Valid());
      it.SeekForPrev(nullptr, encoded_target);
      ASSERT_FALSE(it.Valid());
      it.SeekToLast();
      ASSERT_FALSE(it.Valid());
    }


TEST_F(PureMemRepTest, InsertAndLookup) {
  // Inserts N random keys from [standard_key, standard_key + R) and checks
  // lookup, seeks and short forward/backward scans against a std::set model,
  // using PureMemRep::Iterator.
  const int N = 2000;
  const int R = 5000;
  Random rnd(1000);
  std::set<Key> keys;
  size_t lookahead = 0;
  Allocator* allocator = nullptr;
  const TestKeyComparator compare;
  const SliceTransform* transform = nullptr;
  void* rangearena = nullptr;
  // NOTE(review): `list` and most Encode() buffers are never freed.
  PureMemRep* list = new PureMemRep(compare, allocator, transform, lookahead);

  // Size of the buffer Encode() produces for `key`: varint32 prefix + payload,
  // where payload = decimal digits + 1 ts_size byte + 8 trailing bytes.
  auto encoded_size = [](Key key) -> uint32_t {
    const uint32_t len =
        static_cast<uint32_t>(std::to_string(key).length() + 8 + 1);
    return static_cast<uint32_t>(VarintLength(len)) + len;
  };

  for (int i = 0; i < N; i++) {
    Key key = rnd.Next() % R + standard_key;
    const uint32_t keySize = encoded_size(key);
    const char* encode_key = Encode(&key);
    Slice userkey = Decode(encode_key);

    if (keys.insert(key).second) {
      char* buf = nullptr;
      void* handle = list->AllocatePure(keySize, &buf, userkey, &rangearena);
      memcpy(buf, encode_key, keySize);
      list->Insert(handle);
      list->AllocateOK(userkey, keySize, handle, rangearena);
    }
    PureMemRep::Iterator iter(list->GetARTList());
    ASSERT_FALSE(iter.Valid());
    uint64_t zero = 0;
    iter.Seek(nullptr, Encode(&zero));
    ASSERT_TRUE(iter.Valid());
  }

  // Lookup: keys are drawn from [standard_key, standard_key + R), so probe that
  // range; probing [0, R) would only ever exercise the "absent" branch.
  for (Key i = standard_key; i < standard_key + R; i++) {
    const char* probe = Encode(&i);
    const bool found = list->Contains(probe);
    delete[] probe;
    ASSERT_EQ(keys.count(i), found ? 1U : 0U);
  }

  // Simple iterator tests
  {
    PureMemRep::Iterator iter(list->GetARTList());
    ASSERT_FALSE(iter.Valid());
    uint64_t zero = 0;
    iter.Seek(nullptr, Encode(&zero));
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.begin())), Decode(iter.key()).ToString());

    // Largest key the generator can produce.
    uint64_t max_key = standard_key + R - 1;
    iter.SeekForPrev(nullptr, Encode(&max_key));
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.rbegin())), Decode(iter.key()).ToString());

    iter.SeekToFirst();
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.begin())), Decode(iter.key()).ToString());

    iter.SeekToLast();
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.rbegin())), Decode(iter.key()).ToString());
  }

  // Forward iteration test
  for (Key i = standard_key; i < standard_key + R; i++) {
    PureMemRep::Iterator iter(list->GetARTList());
    iter.Seek(nullptr, Encode(&i));
    // Compare against model iterator
    std::set<Key>::iterator model_iter = keys.lower_bound(i);
    for (int j = 0; j < 3; j++) {
      if (model_iter == keys.end()) {
        ASSERT_FALSE(iter.Valid());
        break;
      }
      ASSERT_TRUE(iter.Valid());
      ASSERT_EQ(std::to_string(*model_iter), Decode(iter.key()).ToString());
      ++model_iter;
      iter.Next();
    }
  }

  // Backward iteration test
  for (Key i = standard_key; i < standard_key + R; i++) {
    PureMemRep::Iterator iter(list->GetARTList());
    iter.SeekForPrev(nullptr, Encode(&i));
    // Compare against model iterator
    std::set<Key>::iterator model_iter = keys.upper_bound(i);
    for (int j = 0; j < 3; j++) {
      if (model_iter == keys.begin()) {
        ASSERT_FALSE(iter.Valid());
        break;
      }
      --model_iter;
      ASSERT_TRUE(iter.Valid());
      ASSERT_EQ(std::to_string(*model_iter), Decode(iter.key()).ToString());
      iter.Prev();
    }
  }
}

TEST_F(PureMemRepTest, HashIterator) {
  // Same scenario as InsertAndLookup, but every scan uses
  // PureMemRep::HashIterator.
  const int N = 2000;
  const int R = 5000;
  Random rnd(1000);
  std::set<Key> keys;
  size_t lookahead = 0;
  Allocator* allocator = nullptr;
  const TestKeyComparator compare;
  const SliceTransform* transform = nullptr;
  void* rangearena = nullptr;
  // NOTE(review): `list` and most Encode() buffers are never freed.
  PureMemRep* list = new PureMemRep(compare, allocator, transform, lookahead);

  // Size of the buffer Encode() produces for `key`: varint32 prefix + payload,
  // where payload = decimal digits + 1 ts_size byte + 8 trailing bytes.
  auto encoded_size = [](Key key) -> uint32_t {
    const uint32_t len =
        static_cast<uint32_t>(std::to_string(key).length() + 8 + 1);
    return static_cast<uint32_t>(VarintLength(len)) + len;
  };

  for (int i = 0; i < N; i++) {
    Key key = rnd.Next() % R + standard_key;
    const uint32_t keySize = encoded_size(key);
    const char* encode_key = Encode(&key);
    Slice userkey = Decode(encode_key);

    if (keys.insert(key).second) {
      char* buf = nullptr;
      void* handle = list->AllocatePure(keySize, &buf, userkey, &rangearena);
      memcpy(buf, encode_key, keySize);
      list->Insert(handle);
      list->AllocateOK(userkey, keySize, handle, rangearena);
    }
    PureMemRep::HashIterator iter(list->GetARTList());
    ASSERT_FALSE(iter.Valid());
    uint64_t zero = 0;
    iter.Seek(nullptr, Encode(&zero));
    ASSERT_TRUE(iter.Valid());
  }

  // Lookup: keys are drawn from [standard_key, standard_key + R), so probe that
  // range; probing [0, R) would only ever exercise the "absent" branch.
  for (Key i = standard_key; i < standard_key + R; i++) {
    const char* probe = Encode(&i);
    const bool found = list->Contains(probe);
    delete[] probe;
    ASSERT_EQ(keys.count(i), found ? 1U : 0U);
  }

  // Simple iterator tests
  {
    PureMemRep::HashIterator iter(list->GetARTList());
    ASSERT_FALSE(iter.Valid());
    uint64_t zero = 0;
    iter.Seek(nullptr, Encode(&zero));
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.begin())), Decode(iter.key()).ToString());

    // Largest key the generator can produce.
    uint64_t max_key = standard_key + R - 1;
    iter.SeekForPrev(nullptr, Encode(&max_key));
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.rbegin())), Decode(iter.key()).ToString());

    iter.SeekToFirst();
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.begin())), Decode(iter.key()).ToString());

    iter.SeekToLast();
    ASSERT_TRUE(iter.Valid());
    ASSERT_EQ(std::to_string(*(keys.rbegin())), Decode(iter.key()).ToString());
  }

  // Forward iteration test
  for (Key i = standard_key; i < standard_key + R; i++) {
    PureMemRep::HashIterator iter(list->GetARTList());
    iter.Seek(nullptr, Encode(&i));
    // Compare against model iterator
    std::set<Key>::iterator model_iter = keys.lower_bound(i);
    for (int j = 0; j < 3; j++) {
      if (model_iter == keys.end()) {
        ASSERT_FALSE(iter.Valid());
        break;
      }
      ASSERT_TRUE(iter.Valid());
      ASSERT_EQ(std::to_string(*model_iter), Decode(iter.key()).ToString());
      ++model_iter;
      iter.Next();
    }
  }

  // Backward iteration test
  for (Key i = standard_key; i < standard_key + R; i++) {
    PureMemRep::HashIterator iter(list->GetARTList());
    iter.SeekForPrev(nullptr, Encode(&i));
    // Compare against model iterator
    std::set<Key>::iterator model_iter = keys.upper_bound(i);
    for (int j = 0; j < 3; j++) {
      if (model_iter == keys.begin()) {
        ASSERT_FALSE(iter.Valid());
        break;
      }
      --model_iter;
      ASSERT_TRUE(iter.Valid());
      ASSERT_EQ(std::to_string(*model_iter), Decode(iter.key()).ToString());
      iter.Prev();
    }
  }
}
}  //namespace rocksdb

int main(int argc, char** argv) {
  // Let GoogleTest parse (and strip) its own command-line flags, then run all
  // registered tests; their aggregate result becomes the exit status.
  ::testing::InitGoogleTest(&argc, argv);
  const int status = RUN_ALL_TESTS();
  return status;
}